{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9e81f0d0",
   "metadata": {},
   "source": [
    "# Chapter 10 - Private Entity Resolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "799a83e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import private_set_intersection.python as psi\n",
    "from pandas import read_csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd0d4f79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base URL of the PSI server; endpoint names are appended to it in later cells\n",
    "url = 'http://localhost:5000/'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46633388",
   "metadata": {},
   "source": [
    "## Client Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ddc689f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the client's item list: company name and postcode joined by a single space\n",
    "df_m = read_csv('mari_clean.csv')\n",
    "client_items = (df_m['CompanyName'] + ' ' + df_m['Postcode']).to_list()\n",
    "\n",
    "# Create a client with a fresh key.\n",
    "# NOTE(review): the True flag is presumably reveal_intersection (return matching\n",
    "# indices rather than only a count) - confirm against the PSI library docs\n",
    "c = psi.client.CreateWithNewKey(True)\n",
    "\n",
    "# Build the request once and reuse it for both the wire format and the cell display,\n",
    "# instead of processing every client item a second time just to show the result\n",
    "client_request = c.CreateRequest(client_items)\n",
    "psirequest = client_request.SerializeToString()\n",
    "client_request"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "779f1690",
   "metadata": {},
   "source": [
    "### Get Server-encrypted Client Values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0c303ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Send the serialized PSI request to the server's 'match' endpoint\n",
    "response = requests.post(url + 'match', headers={'Content-Type': 'application/protobuf'}, data=psirequest)\n",
    "# Stop on an HTTP error instead of handing an error body to the protobuf parser\n",
    "response.raise_for_status()\n",
    "\n",
    "# Decode the reply body into a PSI Response message and display it\n",
    "psiresponse = psi.Response()\n",
    "psiresponse.ParseFromString(response.content)\n",
    "psiresponse"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "563080c9",
   "metadata": {},
   "source": [
    "### Get Server Setup - Raw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9268b455",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the server setup from the 'rawsetup' endpoint\n",
    "setupresponse = requests.get(url + 'rawsetup')\n",
    "# Stop on an HTTP error instead of handing an error body to the protobuf parser\n",
    "setupresponse.raise_for_status()\n",
    "\n",
    "# Decode the reply body into a ServerSetup message and display it\n",
    "rawsetup = psi.ServerSetup()\n",
    "rawsetup.ParseFromString(setupresponse.content)\n",
    "rawsetup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a6a8a32",
   "metadata": {},
   "source": [
    "## Bloom Decode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6b0ee55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the server setup from the 'bloomsetup' endpoint\n",
    "setupresponse = requests.get(url + 'bloomsetup')\n",
    "# Stop on an HTTP error instead of handing an error body to the protobuf parser\n",
    "setupresponse.raise_for_status()\n",
    "\n",
    "# Decode the reply body into a ServerSetup message and display it\n",
    "bloomsetup = psi.ServerSetup()\n",
    "bloomsetup.ParseFromString(setupresponse.content)\n",
    "bloomsetup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1661a146",
   "metadata": {},
   "source": [
    "### Server Calculation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34700d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import ceil, log, log2\n",
    "\n",
    "# Target false-positive rate, divided by the number of client inputs\n",
    "fpr = 0.01\n",
    "num_client_inputs = 100\n",
    "correctedfpr = fpr / num_client_inputs\n",
    "\n",
    "# Size the filter for whichever side holds more elements\n",
    "len_server_items = 2\n",
    "max_elements = max(num_client_inputs, len_server_items)\n",
    "\n",
    "# Bits required, rounded up to a whole number of bytes\n",
    "exact_bits = -max_elements * log2(correctedfpr) / log(2)\n",
    "num_bits = ceil(exact_bits / 8) * 8\n",
    "num_bits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc0e9ebb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from hashlib import sha256\n",
    "\n",
    "# num_bits comes from the Server Calculation cell.\n",
    "# Alternative kept from the original notebook: len(bloomsetup.bloom_filter.bits) * 8\n",
    "filterlist = ['0'] * num_bits\n",
    "hash_count = bloomsetup.bloom_filter.num_hash_functions\n",
    "\n",
    "for encrypted in rawsetup.raw.encrypted_elements:\n",
    "    # Two base hashes per element: SHA-256 of the element prefixed with b'1' and with b'2'\n",
    "    h1 = int.from_bytes(sha256(b'1' + encrypted).digest(), 'big') % num_bits\n",
    "    h2 = int.from_bytes(sha256(b'2' + encrypted).digest(), 'big') % num_bits\n",
    "\n",
    "    # Every probe position is derived from the two base hashes\n",
    "    for i in range(hash_count):\n",
    "        pos = (h1 + i * h2) % num_bits\n",
    "        # Position 0 is the right-most character of the bit string\n",
    "        filterlist[num_bits - 1 - pos] = '1'\n",
    "\n",
    "filterstring = ''.join(filterlist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93231a78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the server's Bloom filter bytes as one bit string (last byte first, 8 bits per byte)\n",
    "# and check it against the locally built filter\n",
    "bloombits = ''.join(f'{byte:08b}' for byte in bloomsetup.bloom_filter.bits[::-1])\n",
    "bloombits == filterstring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f92ad0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of hash functions: -log2 of the corrected false-positive rate, rounded up\n",
    "exact_hash_count = -log2(correctedfpr)\n",
    "num_hash_functions = ceil(exact_hash_count)\n",
    "num_hash_functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae998a07",
   "metadata": {},
   "source": [
    "## GCS Decode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74fc9d19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the server setup from the 'gcssetup' endpoint\n",
    "setupresponse = requests.get(url + 'gcssetup')\n",
    "# Stop on an HTTP error instead of handing an error body to the protobuf parser\n",
    "setupresponse.raise_for_status()\n",
    "\n",
    "# Decode the reply body into a ServerSetup message and display it\n",
    "gcssetup = psi.ServerSetup()\n",
    "gcssetup.ParseFromString(setupresponse.content)\n",
    "gcssetup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fed9e5ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import ceil, log, log2\n",
    "\n",
    "# Same false-positive settings as the Bloom filter section\n",
    "fpr, num_client_inputs = 0.01, 100\n",
    "correctedfpr = fpr / num_client_inputs\n",
    "\n",
    "# max_elements is carried over from the Server Calculation cell\n",
    "hash_range = max_elements / correctedfpr\n",
    "hash_range"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa93d0bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from hashlib import sha256\n",
    "\n",
    "# Hash each server-encrypted element with SHA-256, reduce it modulo the server's\n",
    "# hash_range to get its bucket value, and keep the buckets in ascending order\n",
    "bucket_range = gcssetup.gcs.hash_range\n",
    "ulist = sorted(\n",
    "    int.from_bytes(sha256(encrypted).digest(), 'big') % bucket_range\n",
    "    for encrypted in rawsetup.raw.encrypted_elements\n",
    ")\n",
    "\n",
    "# Deltas between neighbouring bucket values; the first entry is the first bucket itself\n",
    "udiff = [ulist[0]] + [upper - lower for lower, upper in zip(ulist, ulist[1:])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cecf3e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the divisor exponent from the average spacing of the sorted bucket values\n",
    "mean_gap = (ulist[-1] + 1) / len(ulist)\n",
    "hit_probability = 1 / mean_gap\n",
    "raw_exponent = round(-log2(-log2(1.0 - hit_probability)))\n",
    "\n",
    "# The exponent is never allowed to go below zero\n",
    "gcsdiv = max(0, raw_exponent)\n",
    "gcsdiv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c227b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode every non-zero delta as: remainder in binary (zero-padded to div bits),\n",
    "# then a '1' terminator, then one '0' per unit of the quotient.\n",
    "# Each code is placed in front of the previous ones, so the first delta ends up\n",
    "# at the right-hand end of the final bit string.\n",
    "\n",
    "div = gcssetup.gcs.div\n",
    "codes = []\n",
    "for diff in udiff:\n",
    "    if diff == 0:\n",
    "        # Zero deltas (repeated bucket values, or a first bucket of 0) produce no code\n",
    "        continue\n",
    "    # Shift and mask keep the arithmetic in exact integers; int(diff / 2**div)\n",
    "    # goes through a float and loses precision for large deltas\n",
    "    quot = diff >> div\n",
    "    rem = diff & ((1 << div) - 1)\n",
    "    # With div == 0 the remainder occupies no bits at all\n",
    "    rem_bits = format(rem, 'b').zfill(div) if div else ''\n",
    "    codes.append(rem_bits + '1' + '0' * quot)\n",
    "\n",
    "# Join once, in reverse, rather than repeatedly prepending to a growing string\n",
    "encoded = ''.join(reversed(codes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abda53bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left-pad the encoded bit string with 0s so its length is a multiple of 8\n",
    "\n",
    "from math import ceil\n",
    "\n",
    "whole_bytes = ceil(len(encoded) / 8)\n",
    "padlength = whole_bytes * 8\n",
    "# encoded holds only '0' and '1' characters, so right-justifying with '0' is a plain zero fill\n",
    "padded = encoded.rjust(padlength, '0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "689fd90b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the gcs.bits value returned by the server as one bit string\n",
    "# (last byte first, 8 bits per byte) and check it against the locally built encoding\n",
    "\n",
    "gcsbits = ''.join(f'{byte:08b}' for byte in gcssetup.gcs.bits[::-1])\n",
    "gcsbits == padded"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f086f6ea",
   "metadata": {},
   "source": [
    "### Calculate Set Intersection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c4cbac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intersect using the GCS setup; the other two setups can be passed instead:\n",
    "#   intersection = c.GetIntersection(bloomsetup, psiresponse)\n",
    "#   intersection = c.GetIntersection(rawsetup, psiresponse)\n",
    "intersection = c.GetIntersection(gcssetup, psiresponse)\n",
    "iset = set(intersection)\n",
    "\n",
    "# Show the result in ascending order\n",
    "sorted(intersection)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d93f9d8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the client item behind each intersection index, in ascending index order\n",
    "matched_items = [client_items[position] for position in sorted(intersection)]\n",
    "for item in matched_items:\n",
    "    print(item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a39e921f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
